{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ReLU Layers\n",
    "\n",
    "We can write a ReLU layer $z = \\max(Wx+b, 0)$ as the\n",
    "convex optimization problem\n",
    "\\begin{equation}\n",
    "\\begin{array}{ll}\n",
    "\\mbox{minimize} & \\|z-\\tilde Wx - b\\|_2^2 \\\\[.2cm]\n",
    "\\mbox{subject to} & z \\geq 0, \\\\\n",
    "& \\tilde W = W,\n",
    "\\end{array}\n",
    "\\label{eq:prob}\n",
    "\\end{equation}\n",
    "with variables $z$ and $\\tilde W$,\n",
    "and parameters $W$, $b$, and $x$.\n",
    "(Note that we have added an extra variable $\\tilde W$ so\n",
    "that the problem is DPP.)\n",
    "\n",
    "We can embed this problem into a TensorFlow Keras `Layer` and use it\n",
    "as a layer in a sequential neural network.\n",
    "We note that this example is purely illustrative;\n",
    "one can implement a ReLU layer much more efficiently\n",
    "by directly performing the matrix multiplication, vector addition,\n",
    "and then taking the positive part."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cvxpylayers.tensorflow import CvxpyLayer\n",
    "import cvxpy as cp\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "tf.keras.backend.set_floatx('float64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReluLayer(layers.Layer):\n",
    "    \"\"\"ReLU layer z = max(Wx + b, 0) computed by solving a convex problem.\n",
    "\n",
    "    The forward pass solves\n",
    "\n",
    "        minimize    ||z - Wtilde x - b||_2^2\n",
    "        subject to  z >= 0,  Wtilde == W\n",
    "\n",
    "    through a CvxpyLayer, with variables z and Wtilde and parameters\n",
    "    W, b and x. The extra variable Wtilde keeps the problem DPP.\n",
    "\n",
    "    Args:\n",
    "        input_dim: length of the input vector x.\n",
    "        output_dim: length of the output vector z.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        # Trainable weight matrix and bias, initialised with small Gaussian noise.\n",
    "        self.W = tf.Variable(1e-3 * tf.random.normal((output_dim, input_dim), dtype=tf.float64))\n",
    "        self.b = tf.Variable(1e-3 * tf.random.normal((output_dim,), dtype=tf.float64))\n",
    "\n",
    "        # cvxpy description of the problem; these names are the symbolic\n",
    "        # counterparts of the TensorFlow tensors supplied in call().\n",
    "        z = cp.Variable(output_dim)\n",
    "        Wtilde = cp.Variable((output_dim, input_dim))\n",
    "        W = cp.Parameter((output_dim, input_dim))\n",
    "        b = cp.Parameter(output_dim)\n",
    "        x = cp.Parameter(input_dim)\n",
    "        problem = cp.Problem(\n",
    "            cp.Minimize(cp.sum_squares(z - Wtilde @ x - b)),\n",
    "            [z >= 0, Wtilde == W])\n",
    "        # Parameter order here fixes the argument order used in call().\n",
    "        self.cvxpy_layer = CvxpyLayer(problem, [W, b, x], [z])\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Solve the problem for x and return the optimal z.\n",
    "\n",
    "        Args:\n",
    "            x: a single input of shape (input_dim,) or a batch of shape\n",
    "                (batch_size, input_dim).\n",
    "\n",
    "        Returns:\n",
    "            The first (and only) output of the CvxpyLayer, i.e. the solution z.\n",
    "        \"\"\"\n",
    "        if len(x.shape) == 2:\n",
    "            # When x is batched, repeat W and b along a new leading axis so\n",
    "            # that every parameter carries the same batch dimension.\n",
    "            batch_size = x.shape[0]\n",
    "            W_batched = tf.stack([self.W] * batch_size)\n",
    "            b_batched = tf.stack([self.b] * batch_size)\n",
    "            return self.cvxpy_layer(W_batched, b_batched, x)[0]\n",
    "        # Unbatched input: hand the parameters to the solver layer as they are.\n",
    "        return self.cvxpy_layer(self.W, self.b, x)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate synthetic data and create a network of two `ReluLayer`s followed by a linear layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed TensorFlow's global generator before any random tensors are drawn.\n",
    "tf.random.set_seed(0)\n",
    "\n",
    "# Two convex-optimization ReLU layers followed by a single linear output unit.\n",
    "hidden_1 = ReluLayer(20, 20)\n",
    "hidden_2 = ReluLayer(20, 20)\n",
    "readout = layers.Dense(1, input_shape=(20,), dtype=tf.float64)\n",
    "model = tf.keras.Sequential([hidden_1, hidden_2, readout])\n",
    "\n",
    "# Synthetic regression data: 300 samples, 20 features, one scalar target each.\n",
    "X = tf.random.normal(shape=(300, 20), dtype=tf.float64)\n",
    "Y = tf.random.normal(shape=(300, 1), dtype=tf.float64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can optimize the parameters inside the network using, for example, the Adam optimizer.\n",
    "The code below solves 15000 convex optimization problems (25 iterations $\\times$ 300 samples $\\times$ 2 `ReluLayer`s) and calls backward 15000 times."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(1.1182059444179566, shape=(), dtype=float64)\n",
      "tf.Tensor(1.1135983945033263, shape=(), dtype=float64)\n",
      "tf.Tensor(1.1057307058385368, shape=(), dtype=float64)\n",
      "tf.Tensor(1.096769654269908, shape=(), dtype=float64)\n",
      "tf.Tensor(1.084459023932493, shape=(), dtype=float64)\n",
      "tf.Tensor(1.0690279563783676, shape=(), dtype=float64)\n",
      "tf.Tensor(1.0512619022795688, shape=(), dtype=float64)\n",
      "tf.Tensor(1.0322302881090637, shape=(), dtype=float64)\n",
      "tf.Tensor(1.0125819721396856, shape=(), dtype=float64)\n",
      "tf.Tensor(0.9913663215521406, shape=(), dtype=float64)\n",
      "tf.Tensor(0.9653676807090706, shape=(), dtype=float64)\n",
      "tf.Tensor(0.9346387377331861, shape=(), dtype=float64)\n",
      "tf.Tensor(0.901583417900799, shape=(), dtype=float64)\n",
      "tf.Tensor(0.8696330587760338, shape=(), dtype=float64)\n",
      "tf.Tensor(0.8358727832552605, shape=(), dtype=float64)\n",
      "tf.Tensor(0.8012632687553815, shape=(), dtype=float64)\n",
      "tf.Tensor(0.7656949960892767, shape=(), dtype=float64)\n",
      "tf.Tensor(0.732059435220954, shape=(), dtype=float64)\n",
      "tf.Tensor(0.7007173654028725, shape=(), dtype=float64)\n",
      "tf.Tensor(0.6717739637641655, shape=(), dtype=float64)\n",
      "tf.Tensor(0.6479549948680946, shape=(), dtype=float64)\n",
      "tf.Tensor(0.6260994111911655, shape=(), dtype=float64)\n",
      "tf.Tensor(0.6061758074104243, shape=(), dtype=float64)\n",
      "tf.Tensor(0.5910289888009993, shape=(), dtype=float64)\n",
      "tf.Tensor(0.5749367021103264, shape=(), dtype=float64)\n"
     ]
    }
   ],
   "source": [
    "opt = tf.keras.optimizers.Adam(1e-2)\n",
    "for _ in range(25):\n",
    "    # Record the forward pass so the loss can be differentiated afterwards.\n",
    "    with tf.GradientTape() as tape:\n",
    "        residual = Y - model(X)\n",
    "        loss = (1 / X.shape[0]) * tf.math.reduce_sum(residual**2)\n",
    "    # Mean squared error for this iteration.\n",
    "    print(loss)\n",
    "    params = model.variables\n",
    "    grads = tape.gradient(loss, params)\n",
    "    opt.apply_gradients(zip(grads, params))"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
